{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Messages Summarization\n",
    "![Chain](images/summarization.png)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## \"Standard\" chatbot"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain_openai import ChatOpenAI\n",
    "from langchain_core.messages import HumanMessage, SystemMessage\n",
    "from langgraph.graph import START, END, StateGraph\n",
    "from langgraph.checkpoint.memory import MemorySaver\n",
    "from IPython.display import Image, display\n",
    "\n",
    "# OPENAI_API_KEY environment variable must be set\n",
    "llm = ChatOpenAI(model=\"gpt-4o-mini\")\n",
    "\n",
    "# Defining Schema\n",
    "##################################################################################\n",
    "class State(MessagesState):\n",
    "    question: str\n",
    "    answer: str\n",
    "\n",
    "\n",
    "# Defining Agent's node\n",
    "##################################################################################\n",
    "\n",
    "# System message\n",
    "chatbot_system_message = SystemMessage(content=(\"\"\"\n",
    "You are a helpful and knowledgeable chatbot assistant. \n",
    "Your goal is to provide clear and accurate answers to user questions based on the information they provide. \n",
    "Stay focused, concise, and ensure your responses are relevant to the context of the conversation. \n",
    "If you don’t have enough information, ask for clarification.”\n",
    "\"\"\"))\n",
    "\n",
    "# Node\n",
    "def chatbot(state: State) -> State:\n",
    "    question = HumanMessage(content=state.get(\"question\", \"\"))\n",
    "    response = llm.invoke([chatbot_system_message] + state[\"messages\"] + [question]);\n",
    "\n",
    "    return State(\n",
    "        messages = [question, response],\n",
    "        question = state.get(\"question\", None),\n",
    "        answer = response.content\n",
    "    )\n",
    "\n",
    "\n",
    "# Defining Graph\n",
    "##################################################################################\n",
    "\n",
    "builder = StateGraph(State)\n",
    "builder.add_node(\"chatbot\", chatbot)\n",
    "\n",
    "\n",
    "# Define edges: these determine how the control flow moves\n",
    "builder.add_edge(START, \"chatbot\")\n",
    "builder.add_edge(\"chatbot\", END)\n",
    "\n",
    "memory = MemorySaver()\n",
    "chatbot_graph = builder.compile(checkpointer=memory)\n",
    "\n",
    "# Show\n",
    "display(Image(chatbot_graph.get_graph(xray=True).draw_mermaid_png()))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Removing messages\n",
    "\n",
    "1) we use MessagesState\n",
    "2) MessageState has a built-in list of messages (\"messages\" key)\n",
    "3) Key \"messages\" has a built-in reducer \"add_messages\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import Annotated\n",
    "from langgraph.graph.message import add_messages\n",
    "\n",
    "class MessagesState(TypedDict):\n",
    "    messages: Annotated[list[AnyMessage], add_messages]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain_core.messages import HumanMessage, SystemMessage, AIMessage\n",
    "from langgraph.graph.message import add_messages\n",
    "\n",
    "messages = [\n",
    "    HumanMessage(content=\"Message 1\", id=\"1\"),\n",
    "    AIMessage(content=\"Message 2\", id=\"2\"),\n",
    "    HumanMessage(content=\"Message 3\", id=\"3\"),\n",
    "    AIMessage(content=\"Message 4\", id=\"4\"),\n",
    "    HumanMessage(content=\"Message 5\", id=\"5\"),\n",
    "    AIMessage(content=\"Message 6\", id=\"6\")\n",
    "]\n",
    "\n",
    "new_message = HumanMessage(content=\"Message 7\", id=\"7\")\n",
    "\n",
    "# Test\n",
    "messages = add_messages(messages , new_message)\n",
    "messages"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain_core.messages import RemoveMessage\n",
    "\n",
    "# Isolate messages to delete\n",
    "delete_messages = [RemoveMessage(id=m.id) for m in messages[:-2]]\n",
    "delete_messages"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "messages = add_messages(messages , delete_messages)\n",
    "messages"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Summarization chatbot"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from IPython.display import Image, display\n",
    "from langgraph.checkpoint.memory import MemorySaver\n",
    "from langgraph.graph import StateGraph, START, END\n",
    "\n",
    "# Defining Schema\n",
    "##################################################################################\n",
    "class SummaryState(State):\n",
    "    summary: str\n",
    "\n",
    "# Nodes\n",
    "def chatbot(state: SummaryState) -> SummaryState:\n",
    "    summary = state.get(\"summary\", \"\") # getting summary if it exists\n",
    "\n",
    "    # If there is summary, then we add it\n",
    "    if summary:\n",
    "        # define summary as SystemMessage\n",
    "        summary_message = SystemMessage(content=(f\"\"\"\n",
    "        Summary of Conversation:\n",
    "\n",
    "        {summary}\n",
    "        \"\"\"))\n",
    "\n",
    "        messages_with_summary = [summary_message] + state[\"messages\"]\n",
    "    \n",
    "    else:\n",
    "        messages_with_summary = state[\"messages\"]\n",
    "\n",
    "\n",
    "    question = HumanMessage(content=state.get(\"question\", \"\"))\n",
    "\n",
    "    response = llm.invoke([chatbot_system_message] + messages_with_summary + [question]);\n",
    "\n",
    "    return SummaryState(\n",
    "        messages = [question, response],\n",
    "        question = state.get(\"question\", None),\n",
    "        answer = response.content,\n",
    "        summary = state.get(\"summary\", None)\n",
    "    )\n",
    "\n",
    "\n",
    "def summarize(state: SummaryState) -> SummaryState:\n",
    "    summary = state.get(\"summary\", \"\")\n",
    "    # no system message\n",
    "    # the order of components is important\n",
    "\n",
    "    if summary:\n",
    "        summary_message = HumanMessage(content=(f\"\"\"\n",
    "            Expand the summary below by incorporating the above conversation while preserving context, key points, and \n",
    "            user intent. Rework the summary if needed. Ensure that no critical information is lost and that the \n",
    "            conversation can continue naturally without gaps. Keep the summary concise yet informative, removing \n",
    "            unnecessary repetition while maintaining clarity.\n",
    "            \n",
    "            Only return the updated summary. Do not add explanations, section headers, or extra commentary.\n",
    "\n",
    "            Existing summary:\n",
    "\n",
    "            {summary}\n",
    "            \"\"\")\n",
    "        )\n",
    "        \n",
    "    else:\n",
    "        summary_message = HumanMessage(content=\"\"\"\n",
    "        Summarize the above conversation while preserving full context, key points, and user intent. Your response \n",
    "        should be concise yet detailed enough to ensure seamless continuation of the discussion. Avoid redundancy, \n",
    "        maintain clarity, and retain all necessary details for future exchanges.\n",
    "\n",
    "        Only return the summarized content. Do not add explanations, section headers, or extra commentary.\n",
    "        \"\"\")\n",
    "\n",
    "    # Add prompt to our history\n",
    "    messages = state[\"messages\"] + [summary_message]\n",
    "    response = llm.invoke(messages)\n",
    "    \n",
    "    # Delete all but the 2 most recent messages\n",
    "    delete_messages = [RemoveMessage(id=m.id) for m in state[\"messages\"][:-2]]\n",
    "    \n",
    "    return SummaryState(\n",
    "        messages = delete_messages,\n",
    "        question = state.get(\"question\", None),\n",
    "        answer = state.get(\"answer\", None),\n",
    "        summary = response.content\n",
    "    )\n",
    "\n",
    "\n",
    "# Edges\n",
    "\n",
    "# Determine whether to end or summarize the conversation\n",
    "def should_summarize(state: SummaryState):\n",
    "    messages = state[\"messages\"]\n",
    "    \n",
    "    if len(messages) > 2:\n",
    "        return \"summarize\"\n",
    "    \n",
    "    return END\n",
    "\n",
    "\n",
    "# Graph\n",
    "workflow = StateGraph(SummaryState)\n",
    "workflow.add_node(chatbot)\n",
    "workflow.add_node(summarize)\n",
    "\n",
    "workflow.add_edge(START, \"chatbot\")\n",
    "workflow.add_conditional_edges(\"chatbot\", should_summarize)\n",
    "workflow.add_edge(\"summarize\", END)\n",
    "\n",
    "\n",
    "memory = MemorySaver()\n",
    "graph = workflow.compile(checkpointer=memory)\n",
    "display(Image(graph.get_graph().draw_mermaid_png()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "thread_id = \"1\"\n",
    "config = {\"configurable\": {\"thread_id\": thread_id}}\n",
    "\n",
    "graph.invoke(State(question=\"Hi, I’m working on a Python project, and I’m stuck with handling API responses.\"), config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "graph.invoke(State(question=\"Sorry what was my previous question?\"), config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "graph.invoke(State(question=\"Ahh, yeah right! So I’m mostly struggling with parsing JSON responses. Sometimes the structure isn’t what I expect, and it breaks my code.\"), config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "graph.invoke(State(question=\"Got it! That helps a lot. What would be the best way to handle deeply nested JSON data when I only need a few specific values?\"), config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "graph.invoke(State(question=\"How is the weather outside?\"), config)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
